{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1> Case Study: MOMENT for ECG Classification using PTB-XL, a large publicly available electrocardiography dataset</h1>\n",
    "<hr>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Contents\n",
    "### 1. PTB-XL dataset\n",
    "#### &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 1.1 Download PTB-XL dataset\n",
    "### 2. Loading MOMENT\n",
    "### 3. Method 1: Learning a Statistical ML Classifier on MOMENT Embeddings\n",
    "#### &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 3.1 Load PTB-XL dataset\n",
    "#### &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 3.2 Dignostic label Classification using raw ECG signal\n",
    "#### &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 3.3 Dignostic label Classification using MOMENT embedding on ECG signal\n",
    "### 4. Method 2: Finetuning Linear Classification Head\n",
    "### 5. Method 3: Full Finetuning MOMENT\n",
    "#### &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 5.1 Assess MOMENT embedding with SVM after finetuning the encoder\n",
    "#### &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 5.2 Training with Multi-GPU and Parameter Efficient FineTuning (PEFT)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Problem: ECG diagnosis classification with PTB-XL\n",
    "\n",
    "The PTB-XL ECG dataset contains 12-lead clinical ECG forms. The dataset also consists of 5 diagnosis classes assigned to each waveform. \n",
    "\n",
    "The classification problem is formulated to predict diagnosis class label given the patient ECG waveform. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.1 Download PTB-XL dataset\n",
    "\n",
    "PTB-XL is avaliable on Physionet, and can be downloaded in [here](https://physionet.org/content/ptb-xl/1.0.3/)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Loading MOMENT\n",
    "\n",
    "We will first install the MOMENT package, load some essential packages and the pre-trained model. \n",
    "\n",
    "MOMENT can be loaded in 4 modes: (1) `reconstruction`, (2) `embedding`, (3) `forecasting`, and (4) `classification`.\n",
    "\n",
    "In the `reconstruction` mode, MOMENT reconstructs input time series, potentially containing missing values. We can solve imputation and anomaly detection problems in this mode. This mode is suitable for solving imputation and anomaly detection tasks. During pre-training, MOMENT is trained to predict the missing values within uniformly randomly masked patches (disjoint sub-sequences) of the input time series, leveraging information from observed data in other patches. As a result, MOMENT comes equipped with a pre-trained reconstruction head, enabling it to address imputation and anomaly detection challenges in a zero-shot manner! Check out the `anomaly_detection.ipynb` and `imputation.ipynb` notebooks for more details!\n",
    "\n",
    "In the `embedding` model, MOMENT learns a $d$-dimensional embedding (e.g., $d=1024$ for `MOMENT-1-large`) for each input time series. These embeddings can be used for clustering and classification. MOMENT can learn embeddings in a zero-shot setting! Check out `classification.ipynb` notebook for more details! \n",
    "\n",
    "The `forecasting` and `classification` modes are used for forecasting and classification tasks, respectively. In these modes, MOMENT learns representations which are subsequently mapped to the forecast horizon or the number of classes, using linear forecasting and classification heads. Both the forecasting and classification head are randomly initialized, and therefore must be fine-tuned before use. Check out the `forecasting.ipynb` notebook for more details!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install numpy pandas scikit-learn matplotlib tqdm\n",
    "!pip install git+https://github.com/moment-timeseries-foundation-model/moment.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from momentfm import MOMENTPipeline\n",
    "\n",
    "model = MOMENTPipeline.from_pretrained(\n",
    "                                        \"AutonLab/MOMENT-1-large\", \n",
    "                                        model_kwargs={\n",
    "                                            'task_name': 'classification',\n",
    "                                            'n_channels': 12,\n",
    "                                            'num_class': 5,\n",
    "                                            #disable gradient checkpointing to supress the warning when linear probing the model\n",
    "                                            #as MOMENT encoder is frozen\n",
    "                                            'enable_gradient_checkpointing': False,\n",
    "                                            #choose how embedding is obtained from the model\n",
    "                                            #if mean is chosen, the embedding is obtained by averaging the embedding along channel \n",
    "                                            #if concat is chosen, the embedding is obtained by concatenating the embedding along channel\n",
    "                                            #therefore concat would result in a larger embedding size but is much slower to train\n",
    "                                            'reduction': 'mean',\n",
    "                                        },\n",
    "                                        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.init()\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of parameters in the encoder\n",
    "num_params = sum(p.numel() for p in model.encoder.parameters())\n",
    "print(f\"Number of parameters: {num_params}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import os \n",
    "import torch \n",
    "import numpy as np \n",
    "\n",
    "def control_randomness(seed: int = 42):\n",
    "    random.seed(seed)\n",
    "    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "control_randomness(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Method 1: Learning a Statistical ML Classifier\n",
    "\n",
    "In this section, we would like illustrate MOMENT's representation power under zero-shot setting. To do this, we compare following experiment result using a Statistical Classifier, Support Vector Machine (SVM)\n",
    "\n",
    "1. Dignostic label Classification using raw ECG signal\n",
    "2. Dignostic label Classification using MOMENT embedding on raw ECG signal\n",
    "\n",
    "In this setting, we use MOMENT to embed time series data (see `representation_learning.ipynb`). Next, we train a Support Vector Machine (SVM) classifier using these embeddings as features and labels. This setting is common in field of unsupervised representation learning, where the goal is to learn meaningful time series representations without any labeled data (see [TS2Vec](https://arxiv.org/pdf/2106.10466) for a recent example). The quality of these representations are evaluated based on the performance of the downstream classifier (in this case, SVM). This is also the setting that we consider in our [paper](https://arxiv.org/abs/2402.03885). "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.1 Load PTBXL dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once PTB-XL is downloaded, unzip it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from momentfm.data.ptbxl_classification_dataset import PTBXL_dataset\n",
    "import torch \n",
    "\n",
    "class Config:\n",
    "    #path to the unzipped PTB-XL dataset folder\n",
    "    base_path = \n",
    "\n",
    "    #path to cache directory to store preprocessed dataset if needed\n",
    "    #note that preprocessing the dataset is time consuming so you might be benefited to cache it\n",
    "    cache_dir = \n",
    "    load_cache = True\n",
    "\n",
    "    #sampling frequency, choose from 100 or 500\n",
    "    fs = 100\n",
    "\n",
    "    #class to predict\n",
    "    code_of_interest = 'diagnostic_class'\n",
    "    output_type = 'Single'\n",
    "\n",
    "    #sequence length, only support 512 for now\n",
    "    seq_len = 512\n",
    "\n",
    "args = Config()\n",
    "\n",
    "#create dataloader for training and testing\n",
    "train_dataset = PTBXL_dataset(args, phase='train')\n",
    "test_dataset = PTBXL_dataset(args, phase='test')\n",
    "val_dataset = PTBXL_dataset(args, phase='val')\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)\n",
    "val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 Dignostic label Classification using raw ECG signal\n",
    "\n",
    "In this setting, we concat raw ECG signal along the channel dimension, and feed the concatenated time-series directly into a SVM. The goal is to provide a baseline to assess MOMENT embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm \n",
    "import numpy as np \n",
    "\n",
    "def get_timeseries(dataloader: DataLoader, agg='mean'):\n",
    "    '''\n",
    "    We provide two aggregation methods to convert the 12-lead ECG (2-dimensional) to a 1-dimensional time-series for SVM training:\n",
    "    - mean: average over all channels, result in [1 x seq_len] for each time-series\n",
    "    - channel: concat all channels, result in [1 x seq_len * num_channels] for each time-series\n",
    "\n",
    "    labels: [num_samples]\n",
    "    ts: [num_samples x seq_len] or [num_samples x seq_len * num_channels]\n",
    "\n",
    "    *note that concat all channels will result in a much larger feature dimensionality, thus making the fitting process much slower\n",
    "    '''\n",
    "    ts, labels = [], []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch_x, batch_labels in tqdm(dataloader, total=len(dataloader)):\n",
    "            # [batch_size x 12 x 512]\n",
    "            if agg == 'mean':\n",
    "                batch_x = batch_x.mean(dim=1)\n",
    "                ts.append(batch_x.detach().cpu().numpy())\n",
    "            elif agg == 'channel':\n",
    "                ts.append(batch_x.view(batch_x.size(0), -1).detach().cpu().numpy())\n",
    "            labels.append(batch_labels)        \n",
    "\n",
    "    ts, labels = np.concatenate(ts), np.concatenate(labels)\n",
    "    return ts, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit a SVM classifier on the concatenated raw ECG signals\n",
    "from momentfm.models.statistical_classifiers import fit_svm\n",
    "\n",
    "train_embeddings, train_labels = get_timeseries(train_loader, agg='mean')\n",
    "clf = fit_svm(features=train_embeddings, y=train_labels)\n",
    "train_accuracy = clf.score(train_embeddings, train_labels)\n",
    "\n",
    "test_embeddings, test_labels = get_timeseries(test_loader)\n",
    "test_accuracy = clf.score(test_embeddings, test_labels)\n",
    "\n",
    "print(f\"Train accuracy: {train_accuracy:.2f}\")\n",
    "print(f\"Test accuracy: {test_accuracy:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.3 Dignostic label Classification using MOMENT embedding on ECG signal\n",
    "\n",
    "In this setting, we use MOMENT to embed time series data (see `representation_learning.ipynb`). Next, we train a Support Vector Machine (SVM) classifier using these embeddings as features and labels. This setting is common in field of unsupervised representation learning, where the goal is to learn meaningful time series representations without any labeled data (see [TS2Vec](https://arxiv.org/pdf/2106.10466) for a recent example). The quality of these representations are evaluated based on the performance of the downstream classifier (in this case, SVM). This is also the setting that we consider in our [paper](https://arxiv.org/abs/2402.03885). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm \n",
    "import numpy as np \n",
    "from momentfm.models.statistical_classifiers import fit_svm\n",
    "\n",
    "def get_embeddings(model, device, reduction, dataloader: DataLoader):\n",
    "    '''\n",
    "    labels: [num_samples]\n",
    "    embeddings: [num_samples x d_model]\n",
    "    '''\n",
    "    embeddings, labels = [], []\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch_x, batch_labels in tqdm(dataloader, total=len(dataloader)):\n",
    "            # [batch_size x 12 x 512]\n",
    "            batch_x = batch_x.to(device).float()\n",
    "            # [batch_size x num_patches x d_model (=1024)]\n",
    "            output = model(batch_x, reduction=reduction) \n",
    "            #mean over patches dimension, [batch_size x d_model]\n",
    "            embedding = output.embeddings.mean(dim=1)\n",
    "            embeddings.append(embedding.detach().cpu().numpy())\n",
    "            labels.append(batch_labels)        \n",
    "\n",
    "    embeddings, labels = np.concatenate(embeddings), np.concatenate(labels)\n",
    "    return embeddings, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#set device to be 'cuda:0' or 'cuda' if you only have one GPU\n",
    "device = 'cuda:6'\n",
    "reduction = 'mean'\n",
    "train_embeddings, train_labels = get_embeddings(model, device, reduction, train_loader)\n",
    "clf = fit_svm(features=train_embeddings, y=train_labels)\n",
    "train_accuracy = clf.score(train_embeddings, train_labels)\n",
    "\n",
    "test_embeddings, test_labels = get_embeddings(model, device, reduction, test_loader)\n",
    "test_accuracy = clf.score(test_embeddings, test_labels)\n",
    "\n",
    "print(f\"Train accuracy: {train_accuracy:.2f}\")\n",
    "print(f\"Test accuracy: {test_accuracy:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We saw that MOMENT-extracted embedding improves test time accuracy from 60% to 76%! Note that PTB-XL ECG signals does NOT appear in MOMENT pretraining data. This performance improvement shows MOMENT's high quality representation generation ability under zero shot setting."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Method 2: Finetuning the Linear Classification Head only\n",
    "\n",
    "In this setting, we freeze the MOMENT encoder and finetune the linear classification head using Cross Entropy Loss. MOMENT encoder is frozen by default."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch(model, device, train_dataloader, criterion, optimizer, scheduler, reduction='mean'):\n",
    "    '''\n",
    "    Train only classification head\n",
    "    '''\n",
    "    model.to(device)\n",
    "    model.train()\n",
    "    losses = []\n",
    "\n",
    "    for batch_x, batch_labels in train_dataloader:\n",
    "        optimizer.zero_grad()\n",
    "        batch_x = batch_x.to(device).float()\n",
    "        batch_labels = batch_labels.to(device)\n",
    "\n",
    "        #note that since MOMENT encoder is based on T5, it might experiences numerical unstable issue with float16\n",
    "        with torch.autocast(device_type='cuda', dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float32):\n",
    "            output = model(batch_x, reduction=reduction)\n",
    "            loss = criterion(output.logits, batch_labels)\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "        losses.append(loss.item())\n",
    "    \n",
    "    avg_loss = np.mean(losses)\n",
    "    return avg_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_epoch(dataloader, model, criterion, device, phase='val', reduction='mean'):\n",
    "    model.eval()\n",
    "    model.to(device)\n",
    "    total_loss, total_correct = 0, 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch_x, batch_labels in dataloader:\n",
    "            batch_x = batch_x.to(device).float()\n",
    "            batch_labels = batch_labels.to(device)\n",
    "\n",
    "            output = model(batch_x, reduction=reduction)\n",
    "            loss = criterion(output.logits, batch_labels)\n",
    "            total_loss += loss.item()\n",
    "            total_correct += (output.logits.argmax(dim=1) == batch_labels).sum().item()\n",
    "    \n",
    "    avg_loss = total_loss / len(dataloader)\n",
    "    accuracy = total_correct / len(dataloader.dataset)\n",
    "    return avg_loss, accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import numpy as np \n",
    "\n",
    "epoch = 5\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.AdamW(model.head.parameters(), lr=1e-4)\n",
    "scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=epoch * len(train_loader))\n",
    "device = 'cuda:3'\n",
    "\n",
    "for i in tqdm(range(epoch)):\n",
    "    train_loss = train_epoch(model, device, train_loader, criterion, optimizer, scheduler)\n",
    "    val_loss, val_accuracy = evaluate_epoch(val_loader, model, criterion, device, phase='test')\n",
    "    print(f'Epoch {i}, train loss: {train_loss}, val loss: {val_loss}, val accuracy: {val_accuracy}')\n",
    "\n",
    "test_loss, test_accuracy = evaluate_epoch(test_loader, model, criterion, device, phase='test')\n",
    "print(f'Test loss: {test_loss}, test accuracy: {test_accuracy}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Method 3: Full Finetuning MOMENT\n",
    "\n",
    "In this section, we unfreeze MOMENT encoder and finetune the full model on PTB-XL dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#loading MOMENT with encoder unfrozen\n",
    "from momentfm import MOMENTPipeline\n",
    "\n",
    "model = MOMENTPipeline.from_pretrained(\n",
    "                                        \"AutonLab/MOMENT-1-large\", \n",
    "                                        model_kwargs={\n",
    "                                            'task_name': 'classification',\n",
    "                                            'n_channels': 12,\n",
    "                                            'num_class': 5,\n",
    "                                            'freeze_encoder': False,\n",
    "                                            'freeze_embedder': False,\n",
    "                                            'reduction': 'mean',\n",
    "                                        },\n",
    "                                        )\n",
    "model.init()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the learning rate should be smaller to guide the encoder to learn the task without forgetting the pre-trained knowledge\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "epoch = 5\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)\n",
    "scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-4, total_steps=epoch * len(train_loader))\n",
    "device = 'cuda:3'\n",
    "\n",
    "for i in tqdm(range(epoch)):\n",
    "    train_loss = train_epoch(model, device, train_loader, criterion, optimizer, scheduler)\n",
    "    val_loss, val_accuracy = evaluate_epoch(val_loader, model, criterion, device, phase='test')\n",
    "    print(f'Epoch {i}, train loss: {train_loss}, val loss: {val_loss}, val accuracy: {val_accuracy}')\n",
    "\n",
    "test_loss, test_accuracy = evaluate_epoch(test_loader, model, criterion, device, phase='test')\n",
    "print(f'Test loss: {test_loss}, test accuracy: {test_accuracy}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.1 Assess MOMENT embedding with SVM after finetuning the encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#set device to be 'cuda:0' or 'cuda' if you only have one GPU\n",
    "device = 'cuda:3'\n",
    "reduction = 'mean'\n",
    "train_embeddings, train_labels = get_embeddings(model, device, reduction, train_loader)\n",
    "clf = fit_svm(features=train_embeddings, y=train_labels)\n",
    "train_accuracy = clf.score(train_embeddings, train_labels)\n",
    "\n",
    "test_embeddings, test_labels = get_embeddings(model, device, reduction, test_loader)\n",
    "test_accuracy = clf.score(test_embeddings, test_labels)\n",
    "\n",
    "print(f\"Train accuracy: {train_accuracy:.2f}\")\n",
    "print(f\"Test accuracy: {test_accuracy:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We saw that after MOMENT encoder is finetuned for downstream dataset, the embedding gives better test accuracy with SVM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 Training with Multi-GPU and Parameter Efficient FineTuning (PEFT)\n",
    "\n",
    "It might be of interest to the research community with an example to train MOMENT with multi-gpu and PEFT approaches. We also offer a script where this could be achieved.\n",
    "\n",
    "Note that number of processes should be adjusted in the config file at in finetune_demo/ds.ymal according to your setup."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!CUDA_VISIBLE_DEVICES=3,4 accelerate launch --config_file tutorials/finetune_demo/ds.yaml \\\n",
    "    tutorials/finetune_demo/classification.py \\\n",
    "    --base_path path to your ptbxl base folder \\\n",
    "    --cache_dir path to cache directory for preprocessed dataset \\\n",
    "    --mode full_finetuning \\\n",
    "    --output_path path to store train log and checkpoint \\"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The code also supports [LoRA](https://arxiv.org/abs/2106.09685) as a way of doing parameter efficient finetuning. To use LoRA, simply add a flag to the command line above. Currently, LoRA doesn't work well with deepspeed zero3, therefore one might consider switching to stage 2 for LoRA in finetune_demo/ds.ymal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!CUDA_VISIBLE_DEVICES=3,4 accelerate launch --config_file tutorials/finetune_demo/ds.yaml \\\n",
    "    tutorials/finetune_demo/classification.py \\\n",
    "    --base_path path to your ptbxl base folder \\\n",
    "    --cache_dir path to cache directory for preprocessed dataset \\\n",
    "    --mode full_finetuning \\\n",
    "    --output_path path to store train log and checkpoint \\\n",
    "    --lora"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
